1 Domain problem formulation

What is the real-world question? This could be hypothesis-driven or discovery-based.


Why is this question interesting and important? What are the implications of better understanding this data?


Briefly describe any background information necessary to understand this problem.


Briefly describe how this question can be answered in the context of a model or analysis.


Outline the rest of the report/analysis.


2 Data

What is the data under investigation? Provide a brief overview/description of the data.


Describe how your data connects to the domain problem.


2.1 Data Collection

How was the data collected or generated (including details on the experimental design)? Be as transparent as possible so that conclusions made from this data are not misinterpreted down the road.


Describe any limitations when using the data to answer the domain problem of interest.


Where is the data stored, and how can it be accessed by others (if applicable)?


2.2 Data splitting

TODO: add advice for possible data splits

Decide on the proportion of data in each split.

Decide how to split the data (e.g., random sampling, stratified sampling), and explain why this is a reasonable way to split the data.
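As an illustration of the stratified option mentioned above, the following is a minimal base R sketch, not a prescribed method. It samples within each class of `iris$Species` so that every split keeps roughly the same class proportions. The helper `assign_split`, the object names `strat_labels` and `strat_split`, and the seed are assumptions made for this example only.

```r
# Stratified 60/20/20 split: sample within each class so that every split
# keeps roughly the same class proportions as the full data.
set.seed(1)  # arbitrary seed, used only to make the example reproducible
data <- iris
breakdown <- c(train = 0.6, validate = 0.2, test = 0.2)

# Hypothetical helper: randomly assign split labels to the n rows of one stratum.
assign_split <- function(n, breakdown) {
  counts <- floor(n * breakdown)
  counts[1] <- counts[1] + (n - sum(counts))  # leftover rows go to the first split
  sample(rep(names(breakdown), times = counts))
}

strat_labels <- character(nrow(data))
for (idx in split(seq_len(nrow(data)), data$Species)) {
  strat_labels[idx] <- assign_split(length(idx), breakdown)
}
strat_labels <- factor(strat_labels, levels = names(breakdown))
strat_split <- split(data, strat_labels)

# class counts per split
table(strat_labels, data$Species)
```

Stratifying on the response is one reasonable choice when classes are imbalanced; for grouped or time-ordered data, a split by group or by time may be more appropriate.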


Split the data into a training, validation, and test set.

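# TODO: pick more interesting datasets
library(dplyr)  # provides %>% and select()

set.seed(1)  # arbitrary seed so the random split is reproducible
data <- iris
breakdown <- c(train = 0.6, validate = 0.2, test = 0.2)
# cut the row indices into consecutive blocks of the requested sizes,
# then shuffle the block labels so that rows are assigned to splits at random
labels <- sample(cut(
  seq(nrow(data)),
  nrow(data) * cumsum(c(0, breakdown)),
  labels = names(breakdown)
))
data_split <- split(data, labels)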

Xtrain <- data_split$train %>% dplyr::select(-Species)
Xvalid <- data_split$validate %>% dplyr::select(-Species)
Xtest <- data_split$test %>% dplyr::select(-Species)
ytrain <- data_split$train$Species
yvalid <- data_split$validate$Species
ytest <- data_split$test$Species

# FIXME: the tidymodels tuning code needs an rsample "splits" object
X <- iris %>% dplyr::select(-Species)
y <- iris$Species
data_df <- dplyr::bind_cols(.y = y, X)
splits <- rsample::initial_split(data_df)
train_df <- rsample::training(splits)
valid_df <- rsample::testing(splits)

Provide summary statistics and/or figures of the three data sets to illustrate how similar (or different) they are.


2.2.1 Data Split Overview

X Comparison

library(ggplot2)

# reshape each split to long format (one row per observation-feature pair)
train <- reshape2::melt(data_split$train, id.vars = "Species")
train$split <- "train"
validation <- reshape2::melt(data_split$validate, id.vars = "Species")
validation$split <- "validation"
test <- reshape2::melt(data_split$test, id.vars = "Species")
test$split <- "test"
data_all <- rbind(train, validation, test)
data_all$split <- factor(data_all$split,
                         levels = c("train", "validation", "test"))
fnt <- 12
fnt2 <- 12
# overlaid feature densities, one panel per split
ggplot(data_all, aes(x = value, fill = variable)) +
  geom_density(alpha = 0.3) +
  facet_wrap(~ split) +
  xlab("") +
  theme_bw() +
  theme(axis.title.x = element_text(size = fnt, face = "bold"),
        axis.title.y = element_blank(),
        strip.text.x = element_text(size = fnt2, face = "bold"),
        legend.title = element_blank(),
        legend.text = element_text(size = fnt2, face = "bold"))

Y Comparison

# class proportions within each split
plot_data <- data_all %>%
  dplyr::count(Species, split) %>%
  dplyr::group_by(split) %>%
  dplyr::mutate(percent = n / sum(n))
ggplot(plot_data, aes(x = split, y = percent, fill = Species)) +
  geom_col(position = "fill") +
  xlab("") +
  theme_bw() +
  theme(axis.title.x = element_text(size = fnt, face = "bold"),
        axis.title.y = element_blank(),
        strip.text.x = element_text(size = fnt2, face = "bold"),
        legend.title = element_blank(),
        legend.text = element_text(size = fnt2, face = "bold"))

2.3 Data Cleaning and Preprocessing

What steps were taken to clean the data? More importantly, why was the data cleaned in this way?

Discuss all inconsistencies, problems, and oddities in the data (e.g., missing data, errors in data, outliers).
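A quick tabulation can help surface such issues before deciding how to handle them. The sketch below counts missing values per column and flags potential outliers using the 1.5 × IQR rule, which is one common convention rather than the only reasonable one. The helper `flag_outliers` and the use of `iris` are assumptions for illustration.

```r
# Count missing values per column and flag potential outliers with the
# 1.5 * IQR rule (a convention chosen for this example).
X <- iris[, sapply(iris, is.numeric)]

na_counts <- colSums(is.na(X))

# Hypothetical helper: TRUE where a value lies outside the k * IQR fences.
flag_outliers <- function(x, k = 1.5) {
  q <- quantile(x, c(0.25, 0.75), na.rm = TRUE)
  fence <- k * (q[2] - q[1])
  !is.na(x) & (x < q[1] - fence | x > q[2] + fence)
}
outlier_flags <- sapply(X, flag_outliers)   # logical matrix, rows x features
outlier_counts <- colSums(outlier_flags)

data.frame(n_missing = na_counts, n_outliers = outlier_counts)
```

Flagged values are candidates for inspection, not automatic removal; whether they are errors or genuine extremes is a domain question.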

Record your preprocessing steps in a way such that if someone else were to reproduce your analysis, they could easily replicate and understand your steps.

It can be helpful to include relevant plots that explain/justify the choices that were made when cleaning the data.

If more than one preprocessing pipeline is reasonable, examine the impacts of these alternative preprocessing pipelines on the final data results.

Again, be as transparent as possible. This allows others to make their own educated decisions on how best to preprocess the data.


2.4 Data Exploration

TODO: Add drag and drop feature in shiny version for other images

The main goal of this section is to give the reader a feel for what the data “looks like” at a basic level.

Provide plots that summarize the data and perhaps even plots that convey some smaller findings which ultimately motivate the main findings.

Provide additional plots representing remaining oddities after pre-processing if applicable.

Add summary statistics in accompanying tables (or in figures) for quick comparisons.


2.4.1 Data Overview

#> Number of samples: 90
#> Number of features: 4
#> Number of NAs in training y: 0
#> Number of NAs in training X: 0
#> Number of columns in training X with NAs: 0
#> Number of constant columns in training X: 0

Summary Tables

data_types(Xtrain = Xtrain, ytrain = ytrain)
dt_ls <- data_summary(Xtrain = Xtrain, ytrain = ytrain, digits = 2, sigfig = FALSE)
for (dt_name in names(dt_ls)) {
  subchunkify(dt_ls[[dt_name]], i = chunk_idx, other_args = "results='asis'")
  chunk_idx <- chunk_idx + 1
}

X distribution

# plot X distribution
plot_X_distribution(Xtrain, "density")

Y distribution

# plot y distribution
plot_y_distribution(ytrain, "bar")

Feature Correlation

# correlation heatmap
plotCorHeatmap(X = Xtrain, cor.type = "pearson", clust = TRUE, text.size = 8)

Feature Pair Plots

# pair plots
col_ids <- 1:min(ncol(Xtrain), 6)
plotPairs(data = Xtrain, columns = col_ids, 
          color = ytrain, color.label = "y")

Marginal Association Plots

caret::featurePlot(x = Xtrain,
                   y = ytrain,
                   plot = if (is.factor(ytrain)) "box" else "scatter",
                   # strip = strip.custom(par.strip.text = list(cex = .7)),
                   scales = list(x = list(relation = "free"), 
                                 y = list(relation = "free")))

PCA Plots

# dimension reduction plots
plotPCA(X = Xtrain, npcs = 3, color = ytrain, color.label = "y",
        center = TRUE, scale = FALSE)$plot

For inspiration: Shiny App

3 Prediction Modeling

TODO: add advice on which models to select and why

Discuss the prediction methods under consideration, and explain why these methods were chosen.


Discuss the accuracy metrics under consideration, and explain why these metrics were chosen.


Note: there should be multiple methods and metrics under consideration to paint a more holistic picture of the data. At least one method should be a baseline, common approach that may not be optimal for the problem setting, but serves as a helpful comparison.
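To make the idea of a baseline and multiple metrics concrete, here is a minimal base R sketch under stated assumptions: a majority-class predictor serves as the baseline, and plain accuracy is compared with balanced accuracy (mean per-class recall). The names `ytrain_toy`, `yvalid_toy`, `accuracy`, and `balanced_accuracy` are hypothetical and exist only for this example.

```r
# Majority-class baseline evaluated with two accuracy metrics (base R only).
set.seed(1)  # arbitrary seed for a reproducible toy split
idx <- sample(nrow(iris), size = 100)
ytrain_toy <- iris$Species[idx]
yvalid_toy <- iris$Species[-idx]

# Baseline: always predict the most frequent class in the training data.
majority <- names(which.max(table(ytrain_toy)))
baseline_pred <- factor(rep(majority, length(yvalid_toy)),
                        levels = levels(yvalid_toy))

# Fraction of observations predicted correctly.
accuracy <- function(pred, obs) mean(pred == obs)

# Mean of per-class recall, over the classes that occur in obs.
balanced_accuracy <- function(pred, obs) {
  recalls <- tapply(pred == obs, obs, mean)
  mean(recalls, na.rm = TRUE)
}

c(accuracy = accuracy(baseline_pred, yvalid_toy),
  balanced_accuracy = balanced_accuracy(baseline_pred, yvalid_toy))
```

Reporting both metrics shows why a single number can mislead: the baseline's plain accuracy equals the majority-class share, while its balanced accuracy exposes that it never predicts the other classes.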

3.1 Prediction check

Carry out the prediction pipeline, outlined above.

  • Fit prediction methods on training data.
  • Evaluate prediction methods on validation data.
  • Compare results, and filter out poor models.


caret

# how to do cross validation
trcontrol <- caret::trainControl(
  method = "cv",
  number = 5,
  classProbs = is.factor(ytrain),
  summaryFunction = caret::defaultSummary,
  allowParallel = FALSE,
  verboseIter = FALSE
)

response <- "raw"
model_list <- list(
  # mtry must be a whole number, so round the candidate values and drop duplicates
  ranger = list(tuneGrid = expand.grid(mtry = unique(round(seq(sqrt(ncol(Xtrain)),
                                                               ncol(Xtrain) / 3,
                                                               length.out = 3))),
                                       splitrule = "gini",
                                       min.node.size = 1),
                importance = "impurity",
                num.threads = 1),
  xgbTree = list(tuneGrid = expand.grid(nrounds = c(10, 25, 50, 100, 150),
                                        max_depth = c(3, 6),
                                        colsample_bytree = 0.33,
                                        eta = c(0.1, 0.3),
                                        gamma = 0,
                                        min_child_weight = 1,
                                        subsample = 0.6),
                 nthread = 1)
)

model_fits <- list()
model_preds <- list()
model_errs <- list()
model_vimps <- list()
for (model_name in names(model_list)) {
  mod <- model_list[[model_name]]
  if (identical(mod, list())) {
    mod <- NULL
  }
  mod_fit <- do.call(caret::train, args = c(list(x = as.data.frame(Xtrain),
                                                 y = ytrain,
                                                 trControl = trcontrol,
                                                 method = model_name),
                                            mod))
  model_fits[[model_name]] <- mod_fit
  model_preds[[model_name]] <- predict(mod_fit, as.data.frame(Xvalid),
                                       type = response)
  model_errs[[model_name]] <- caret::postResample(
    pred = model_preds[[model_name]], obs = yvalid
  )
  model_vimps[[model_name]] <- caret::varImp(mod_fit)
}
model_preds <- dplyr::bind_rows(model_preds, .id = "model")
model_errs <- dplyr::bind_rows(model_errs, .id = "model")
model_vimps <- purrr::map_dfr(model_vimps, 
                              ~.x[["importance"]] %>% 
                                tibble::rownames_to_column("variable"),
                              .id = "model")

tidymodels

# TODO: add code for tuning parameters
# a recipe is specified from a data frame, so use the training data
mod_recipe <- recipes::recipe(.y ~ ., data = train_df)

# for classification
rf_model <- parsnip::rand_forest() %>%
  parsnip::set_args(mtry = tune::tune()) %>%
  parsnip::set_engine("ranger", importance = "impurity") %>%
  parsnip::set_mode("classification")
rf_grid <- tidyr::crossing(mtry = 1:4)

svm_model <- parsnip::svm_rbf() %>%
  parsnip::set_engine("kernlab") %>%
  parsnip::set_mode("classification")

knn_model <- parsnip::nearest_neighbor() %>%
   parsnip::set_args(neighbors = tune::tune(), weight_func = tune::tune()) %>% 
   parsnip::set_engine("kknn") %>% 
   parsnip::set_mode("classification")

# models <- workflowsets::workflow_set(
#   preproc = list(Base = mod_recipe),
#   models = list(RF = rf_model, SVM = svm_model, KNN = knn_model),
#   cross = TRUE
# ) %>%
#   workflowsets::option_add(grid = rf_grid, id = "Base_RF")
# model_fits <- workflowsets::workflow_map(
#   object = models,
#   fn = "tune_grid"
# )

model_list <- list(RF = list(model = rf_model,
                             grid = rf_grid), 
                   SVM = list(model = svm_model,
                              grid = NULL), 
                   KNN = list(model = knn_model,
                              grid = 4))
model_fits <- list()
model_preds <- list()
model_errs <- list()
model_vimps <- list()
for (model_name in names(model_list)) {
  mod <- model_list[[model_name]]$model
  grid <- model_list[[model_name]]$grid
  if (!is.null(grid)) {
    mod_fit <- workflows::workflow() %>%
      workflows::add_recipe(mod_recipe) %>%
      workflows::add_model(mod)
    best_params <- mod_fit %>%
      tune::tune_grid(resamples = rsample::vfold_cv(train_df), 
                      grid = grid) %>%
      tune::select_best(metric = "accuracy")
    mod_fit <- mod_fit %>%
      tune::finalize_workflow(best_params) %>%
      tune::last_fit(splits)
  } else {
    mod_fit <- workflows::workflow() %>%
      workflows::add_recipe(mod_recipe) %>%
      workflows::add_model(mod) %>%
      tune::last_fit(splits)
  }
  model_fits[[model_name]] <- mod_fit
  model_preds[[model_name]] <- mod_fit %>%
    tune::collect_predictions()
  model_errs[[model_name]] <- mod_fit %>%
    tune::collect_metrics()
  model_vimps[[model_name]] <- tryCatch({
    # model-specific variable importance
    mod_fit %>%
      workflows::extract_fit_parsnip() %>%
      vip::vi()
  }, error = function(e) {
    # model-agnostic permutation variable importance
    mod_fit %>%
      workflows::extract_fit_parsnip() %>%
      vip::vi(method = "permute", train = train_df, target = ".y",
              feature_names = setdiff(colnames(train_df), ".y"), 
              pred_wrapper = predict, metric = "accuracy")
  })
}
model_preds <- dplyr::bind_rows(model_preds, .id = "model")
model_errs <- dplyr::bind_rows(model_errs, .id = "model")
model_vimps <- dplyr::bind_rows(model_vimps, .id = "model")

h2o

library(h2o)
h2o.init(nthreads = 1)
iris.hex <- as.h2o(iris)
splits <- h2o.splitFrame(data = iris.hex,
                         ratios = c(0.8))
train_df <- splits[[1]]
valid_df <- splits[[2]]

model_list <- list(randomForest = list(ntrees = 500), 
                   xgboost = list())

model_fits <- list()
model_preds <- list()
model_errs <- list()
model_vimps <- list()
for (model_name in names(model_list)) {
  mod <- model_list[[model_name]]
  if (identical(mod, list())) {
    mod <- NULL
  }
  mod_fit <- do.call(paste0("h2o.", model_name),
                     args = c(list(x = colnames(Xtrain),
                                   y = "Species",
                                   training_frame = train_df,
                                   model_id = model_name),
                              mod))
  model_fits[[model_name]] <- mod_fit
  model_preds[[model_name]] <- h2o.predict(mod_fit, valid_df)
  model_errs[[model_name]] <- h2o.performance(mod_fit, valid_df)
  model_vimps[[model_name]] <- h2o.varimp(mod_fit)
}
model_preds <- purrr::map_dfr(model_preds, ~attr(.x, "data"), .id = "model")
model_errs <- purrr::map_dfr(
  model_errs, 
  function(err) {
    rm_objs <- c("model", "model_checksum", "frame", "frame_checksum",
                 "description", "scoring_time", "predictions")
    simChef:::simplify_tibble(simChef:::list_to_tibble_row(
      err@metrics[setdiff(names(err@metrics), rm_objs)]
    ))
  }, 
  .id = "model"
)
model_vimps <- dplyr::bind_rows(model_vimps, .id = "model")

3.2 Stability check

Taking the prediction methods that pass the prediction check, perform stability analysis.

  • Specify and justify the appropriate data perturbation(s).
  • Re-fit the prediction methods on these perturbed data sets.
  • Evaluate prediction methods on validation data.
  • Assess stability across the data perturbations as well as across the various methods.
  • Filter out poor models where necessary and interpret stability results.


TODO: Add results for tidymodels and h2o in addition to caret

caret

nrep = 5 # increase for better stability measures when not testing code
model_preds_b <- list()
model_errs_b <- list()
model_vimps_b <- list()
for (b in 1:nrep){
  # bootstrap the training set (sample rows with replacement)
  bootstrap = sample(seq_len(nrow(Xtrain)), nrow(Xtrain), replace = TRUE)
  Xtrain_b = Xtrain[bootstrap,]
  ytrain_b = ytrain[bootstrap]
  
  # how to do cross validation
  trcontrol <- caret::trainControl(
    method = "cv",
    number = 5,
    classProbs = is.factor(ytrain_b),
    summaryFunction = caret::defaultSummary,
    allowParallel = FALSE,
    verboseIter = FALSE
  )
  
  response <- "raw"
  model_list <- list(
    # mtry must be a whole number, so round the candidate values and drop duplicates
    ranger = list(tuneGrid = expand.grid(mtry = unique(round(seq(sqrt(ncol(Xtrain_b)),
                                                                 ncol(Xtrain_b) / 3,
                                                                 length.out = 3))),
                                         splitrule = "gini",
                                         min.node.size = 1),
                  importance = "impurity",
                  num.threads = 1),
    xgbTree = list(tuneGrid = expand.grid(nrounds = c(10, 25, 50, 100, 150),
                                          max_depth = c(3, 6),
                                          colsample_bytree = 0.33,
                                          eta = c(0.1, 0.3),
                                          gamma = 0,
                                          min_child_weight = 1,
                                          subsample = 0.6),
                   nthread = 1)
  )
  
  model_fits_temp <- list()
  model_preds_temp <- list()
  model_errs_temp <- list()
  model_vimps_temp <- list()
  for (model_name in names(model_list)) {
    mod <- model_list[[model_name]]
    if (identical(mod, list())) {
      mod <- NULL
    }
    mod_fit <- do.call(caret::train, args = c(list(x = as.data.frame(Xtrain_b),
                                                   y = ytrain_b,
                                                   trControl = trcontrol,
                                                   method = model_name),
                                              mod))
    model_fits_temp[[model_name]] <- mod_fit
    model_preds_temp[[model_name]] <- predict(mod_fit, as.data.frame(Xvalid),
                                         type = response)
    model_errs_temp[[model_name]] <- caret::postResample(
      pred = model_preds_temp[[model_name]], obs = yvalid
    )
    model_vimps_temp[[model_name]] <- caret::varImp(mod_fit)
  }
  
  model_preds_temp <- dplyr::bind_rows(model_preds_temp, .id = "model")
  model_errs_temp <- dplyr::bind_rows(model_errs_temp, .id = "model")
  model_vimps_temp <- purrr::map_dfr(model_vimps_temp, 
                                ~.x[["importance"]] %>% 
                                  tibble::rownames_to_column("variable"),
                                .id = "model")
  model_preds_b[[b]] <- model_preds_temp
  model_errs_b[[b]] <- model_errs_temp
  model_vimps_b[[b]] <- model_vimps_temp
}

model_errs_b <- bind_rows(model_errs_b, .id = "column_label")
model_errs_b_tmp <- model_errs_b
model_errs_b <- model_errs_b %>%
                group_by(model) %>%
                summarise(mean_accuracy = mean(Accuracy), sd_accuracy =  sd(Accuracy),
                          mean_kappa = mean(Kappa), sd_kappa = sd(Kappa))

tidymodels

nrep = 1 # increase for better stability measures when not testing code
model_preds_b <- list()
model_errs_b <- list()
model_vimps_b <- list()
for (b in 1:nrep){
  # TODO: add code for tuning parameters
  # a recipe is specified from a data frame, so use the training data
  mod_recipe <- recipes::recipe(.y ~ ., data = train_df)
  
  # for classification
  rf_model <- parsnip::rand_forest() %>%
    parsnip::set_args(mtry = tune::tune()) %>%
    parsnip::set_engine("ranger", importance = "impurity") %>%
    parsnip::set_mode("classification")
  rf_grid <- tidyr::crossing(mtry = 1:4)
  
  svm_model <- parsnip::svm_rbf() %>%
    parsnip::set_engine("kernlab") %>%
    parsnip::set_mode("classification")
  
  knn_model <- parsnip::nearest_neighbor() %>%
     parsnip::set_args(neighbors = tune::tune(), weight_func = tune::tune()) %>% 
     parsnip::set_engine("kknn") %>% 
     parsnip::set_mode("classification")
  
  # models <- workflowsets::workflow_set(
  #   preproc = list(Base = mod_recipe),
  #   models = list(RF = rf_model, SVM = svm_model, KNN = knn_model),
  #   cross = TRUE
  # ) %>%
  #   workflowsets::option_add(grid = rf_grid, id = "Base_RF")
  # model_fits <- workflowsets::workflow_map(
  #   object = models,
  #   fn = "tune_grid"
  # )
  
  model_list <- list(RF = list(model = rf_model,
                               grid = rf_grid), 
                     SVM = list(model = svm_model,
                                grid = NULL), 
                     KNN = list(model = knn_model,
                                grid = 4))
  model_fits_temp <- list()
  model_preds_temp <- list()
  model_errs_temp <- list()
  model_vimps_temp <- list()
  for (model_name in names(model_list)) {
    mod <- model_list[[model_name]]$model
    grid <- model_list[[model_name]]$grid
    if (!is.null(grid)) {
      mod_fit <- workflows::workflow() %>%
        workflows::add_recipe(mod_recipe) %>%
        workflows::add_model(mod)
      best_params <- mod_fit %>%
        tune::tune_grid(resamples = rsample::vfold_cv(train_df), 
                        grid = grid) %>%
        tune::select_best(metric = "accuracy")
      mod_fit <- mod_fit %>%
        tune::finalize_workflow(best_params) %>%
        tune::last_fit(splits)
    } else {
      mod_fit <- workflows::workflow() %>%
        workflows::add_recipe(mod_recipe) %>%
        workflows::add_model(mod) %>%
        tune::last_fit(splits)
    }
    model_fits_temp[[model_name]] <- mod_fit
    model_preds_temp[[model_name]] <- mod_fit %>%
      tune::collect_predictions()
    model_errs_temp[[model_name]] <- mod_fit %>%
      tune::collect_metrics()
    model_vimps_temp[[model_name]] <- tryCatch({
      # model-specific variable importance
      mod_fit %>%
        workflows::extract_fit_parsnip() %>%
        vip::vi()
    }, error = function(e) {
      # model-agnostic permutation variable importance
      mod_fit %>%
        workflows::extract_fit_parsnip() %>%
        vip::vi(method = "permute", train = train_df, target = ".y",
                feature_names = setdiff(colnames(train_df), ".y"), 
                pred_wrapper = predict, metric = "accuracy")
    })
  }
  model_preds_temp <- dplyr::bind_rows(model_preds_temp, .id = "model")
  model_errs_temp <- dplyr::bind_rows(model_errs_temp, .id = "model")
  model_vimps_temp <- dplyr::bind_rows(model_vimps_temp, .id = "model")

  model_preds_b[[b]] <- model_preds_temp
  model_errs_b[[b]] <- model_errs_temp
  model_vimps_b[[b]] <- model_vimps_temp
}

model_errs_b <- bind_rows(model_errs_b, .id = "column_label")

# summarize each metric across replicates (sd is NA when nrep = 1);
# tune::collect_metrics() returns long-format .metric/.estimate columns
model_errs_b <- model_errs_b %>%
  dplyr::group_by(model, .metric) %>%
  dplyr::summarise(mean = mean(.estimate), sd = sd(.estimate),
                   .groups = "drop")

h2o

3.3 Interpretability

For the models that pass the prediction and stability checks,

  • Extract the important features in the predictive models that are stable across both data and model perturbations. Determining the importance of a feature can be method dependent.


3.3.1 Without stability

prettyDT(model_vimps, digits = 2, sigfig = FALSE, caption = "Variable Importances")
# bar plot
vip::vip(model_vimps,
         num_features = 10,
         geom = "col") +
  prettyGGplotTheme()

# scatter plot
plt <- model_vimps %>%
  tidyr::pivot_wider(names_from = "model", values_from = "Importance") %>%
  plotPairs(columns = which(!(colnames(.) %in% "Variable"))) +
  ggplot2::theme_bw() 
plotly::ggplotly(plt)

3.3.2 With stability

TODO
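Until this subsection is filled in, the following sketch shows one possible shape for a stability-aware importance summary. It is an illustration under explicit assumptions, not the template's method: the importance score is a stand-in (the one-way ANOVA F statistic of each feature against the class label) so that the example runs in base R, and `f_stat`, `vimps_b`, and `stability` are hypothetical names. In practice the per-replicate importances would come from the fitted models in the stability check.

```r
# Bootstrap stability of a simple marginal importance score.
# Stand-in importance (an assumption for illustration): one-way ANOVA
# F statistic of each feature against the class label.
set.seed(1)  # arbitrary seed for reproducibility
X <- iris[, 1:4]
y <- iris$Species
nrep <- 20

f_stat <- function(x, y) summary(aov(x ~ y))[[1]][["F value"]][1]

vimps_b <- lapply(seq_len(nrep), function(b) {
  boot <- sample(nrow(X), replace = TRUE)  # bootstrap rows with replacement
  data.frame(rep = b,
             variable = colnames(X),
             importance = unname(sapply(X[boot, ], f_stat, y = y[boot])))
})
vimps_b <- do.call(rbind, vimps_b)

# Rank features within each replicate (rank 1 = most important).
vimps_b$rank <- ave(-vimps_b$importance, vimps_b$rep, FUN = rank)

# Summarize importance and rank across replicates.
stability <- data.frame(
  mean_importance = sapply(split(vimps_b$importance, vimps_b$variable), mean),
  sd_importance = sapply(split(vimps_b$importance, vimps_b$variable), sd),
  mean_rank = sapply(split(vimps_b$rank, vimps_b$variable), mean)
)
stability
```

Features whose rank stays consistently high across replicates are better candidates for interpretation than features whose importance swings with the perturbation.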

4 Main Results

Interpret and summarize the prediction and stability results.


Evaluate pipeline on test data.

Summarize test set prediction and/or interpretability results.


caret

# how to do cross validation
trcontrol <- caret::trainControl(
  method = "cv",
  number = 5,
  classProbs = is.factor(ytrain),
  summaryFunction = caret::defaultSummary,
  allowParallel = FALSE,
  verboseIter = FALSE
)

response <- "raw"
model_list <- list(
  # mtry must be a whole number, so round the candidate values and drop duplicates
  ranger = list(tuneGrid = expand.grid(mtry = unique(round(seq(sqrt(ncol(Xtrain)),
                                                               ncol(Xtrain) / 3,
                                                               length.out = 3))),
                                       splitrule = "gini",
                                       min.node.size = 1),
                importance = "impurity",
                num.threads = 1),
  xgbTree = list(tuneGrid = expand.grid(nrounds = c(10, 25, 50, 100, 150),
                                        max_depth = c(3, 6),
                                        colsample_bytree = 0.33,
                                        eta = c(0.1, 0.3),
                                        gamma = 0,
                                        min_child_weight = 1,
                                        subsample = 0.6),
                 nthread = 1)
)

model_fits <- list()
model_preds <- list()
model_errs <- list()
model_vimps <- list()
for (model_name in names(model_list)) {
  mod <- model_list[[model_name]]
  if (identical(mod, list())) {
    mod <- NULL
  }
  mod_fit <- do.call(caret::train, args = c(list(x = as.data.frame(Xtrain),
                                                 y = ytrain,
                                                 trControl = trcontrol,
                                                 method = model_name),
                                            mod))
  model_fits[[model_name]] <- mod_fit
  model_preds[[model_name]] <- predict(mod_fit, as.data.frame(Xtest),
                                       type = response)
  model_errs[[model_name]] <- caret::postResample(
    pred = model_preds[[model_name]], obs = ytest
  )
  model_vimps[[model_name]] <- caret::varImp(mod_fit)
}
model_preds <- dplyr::bind_rows(model_preds, .id = "model")
model_errs <- dplyr::bind_rows(model_errs, .id = "model")
model_vimps <- purrr::map_dfr(model_vimps, 
                              ~.x[["importance"]] %>% 
                                tibble::rownames_to_column("variable"),
                              .id = "model")

tidymodels

# TODO: add code for tuning parameters
# a recipe is specified from a data frame, so use the training data
mod_recipe <- recipes::recipe(.y ~ ., data = train_df)

# for classification
rf_model <- parsnip::rand_forest() %>%
  parsnip::set_args(mtry = tune::tune()) %>%
  parsnip::set_engine("ranger", importance = "impurity") %>%
  parsnip::set_mode("classification")
rf_grid <- tidyr::crossing(mtry = 1:4)

svm_model <- parsnip::svm_rbf() %>%
  parsnip::set_engine("kernlab") %>%
  parsnip::set_mode("classification")

knn_model <- parsnip::nearest_neighbor() %>%
   parsnip::set_args(neighbors = tune::tune(), weight_func = tune::tune()) %>% 
   parsnip::set_engine("kknn") %>% 
   parsnip::set_mode("classification")

# models <- workflowsets::workflow_set(
#   preproc = list(Base = mod_recipe),
#   models = list(RF = rf_model, SVM = svm_model, KNN = knn_model),
#   cross = TRUE
# ) %>%
#   workflowsets::option_add(grid = rf_grid, id = "Base_RF")
# model_fits <- workflowsets::workflow_map(
#   object = models,
#   fn = "tune_grid"
# )

model_list <- list(RF = list(model = rf_model,
                             grid = rf_grid), 
                   SVM = list(model = svm_model,
                              grid = NULL), 
                   KNN = list(model = knn_model,
                              grid = 4))
model_fits <- list()
model_preds <- list()
model_errs <- list()
model_vimps <- list()
for (model_name in names(model_list)) {
  mod <- model_list[[model_name]]$model
  grid <- model_list[[model_name]]$grid
  if (!is.null(grid)) {
    mod_fit <- workflows::workflow() %>%
      workflows::add_recipe(mod_recipe) %>%
      workflows::add_model(mod)
    best_params <- mod_fit %>%
      tune::tune_grid(resamples = rsample::vfold_cv(train_df), 
                      grid = grid) %>%
      tune::select_best(metric = "accuracy")
    mod_fit <- mod_fit %>%
      tune::finalize_workflow(best_params) %>%
      tune::last_fit(splits)
  } else {
    mod_fit <- workflows::workflow() %>%
      workflows::add_recipe(mod_recipe) %>%
      workflows::add_model(mod) %>%
      tune::last_fit(splits)
  }
  model_fits[[model_name]] <- mod_fit
  model_preds[[model_name]] <- mod_fit %>%
    tune::collect_predictions()
  model_errs[[model_name]] <- mod_fit %>%
    tune::collect_metrics()
  model_vimps[[model_name]] <- tryCatch({
    # model-specific variable importance
    mod_fit %>%
      workflows::extract_fit_parsnip() %>%
      vip::vi()
  }, error = function(e) {
    # model-agnostic permutation variable importance
    mod_fit %>%
      workflows::extract_fit_parsnip() %>%
      vip::vi(method = "permute", train = train_df, target = ".y",
              feature_names = setdiff(colnames(train_df), ".y"), 
              pred_wrapper = predict, metric = "accuracy")
  })
}
model_preds <- dplyr::bind_rows(model_preds, .id = "model")
model_errs <- dplyr::bind_rows(model_errs, .id = "model")
model_vimps <- dplyr::bind_rows(model_vimps, .id = "model")

h2o

5 Post hoc analysis

Move beyond the global prediction accuracy metrics and dive deeper into individual-level predictions for the validation and/or test set, i.e., provide a more “local” analysis.

  • Examine any points that had poor predictions.
  • Examine differences between prediction methods.
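The two checks listed above can be sketched in base R as follows. This is a minimal illustration with two deliberately simple stand-in classifiers (nearest class centroid on all features versus on the sepal features only); the helper `nearest_centroid` and the objects `pred_all`, `pred_sepal`, and `local` are hypothetical names, and the models from the prediction check would take their place in a real analysis.

```r
# Compare two simple classifiers observation by observation.
set.seed(1)  # arbitrary seed for a reproducible toy split
idx <- sample(nrow(iris), size = 100)
train <- iris[idx, ]
valid <- iris[-idx, ]

# Hypothetical helper: nearest class centroid using the given features.
nearest_centroid <- function(train, newdata, features) {
  centroids <- aggregate(train[, features, drop = FALSE],
                         by = list(Species = train$Species), FUN = mean)
  # squared Euclidean distance from every new row to every class centroid
  d <- apply(centroids[, features, drop = FALSE], 1, function(ctr) {
    rowSums(sweep(as.matrix(newdata[, features, drop = FALSE]), 2, ctr)^2)
  })
  centroids$Species[max.col(-d, ties.method = "first")]
}

pred_all <- nearest_centroid(train, valid, colnames(iris)[1:4])
pred_sepal <- nearest_centroid(train, valid, c("Sepal.Length", "Sepal.Width"))

local <- data.frame(valid, pred_all = pred_all, pred_sepal = pred_sepal)
wrong_all <- as.character(local$pred_all) != as.character(local$Species)
wrong_sepal <- as.character(local$pred_sepal) != as.character(local$Species)

# points that at least one method predicts poorly
local[wrong_all | wrong_sepal, ]

# where the two methods agree and disagree
table(all_features = local$pred_all, sepal_only = local$pred_sepal)
```

Observations misclassified by every method often point to label noise or overlapping classes, while observations on which methods disagree show where the choice of model actually matters.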


TODO: Tiffany - Add examples with interesting observations of prediction accuracy metrics so the user knows what to look for.

model_preds %>%
  tidyr::pivot_wider(names_from = "model", values_from = ".pred_setosa", 
                     id_cols = c("id", ".row")) %>%
  plotPairs(columns = which(!(colnames(.) %in% c("id", ".row"))))

6 Conclusions

Reiterate main findings, note any caveats, and clearly translate findings/analysis back to the domain problem context.

